{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "view-in-github"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/Ankur3107/GitHub-Bugs-Prediction-Challenge/blob/main/model-template/MachineHack_xlnet_base.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 615
    },
    "id": "zmJTzDTonAGI",
    "outputId": "ac7bcfa0-675c-402d-d643-53df15948caa"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collecting transformers\n",
      "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/19/22/aff234f4a841f8999e68a7a94bdd4b60b4cebcfeca5d67d61cd08c9179de/transformers-3.3.1-py3-none-any.whl (1.1MB)\n",
      "\u001b[K     |████████████████████████████████| 1.1MB 3.5MB/s \n",
      "\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from transformers) (1.18.5)\n",
      "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.6/dist-packages (from transformers) (4.41.1)\n",
      "Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from transformers) (2.23.0)\n",
      "Collecting tokenizers==0.8.1.rc2\n",
      "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/80/83/8b9fccb9e48eeb575ee19179e2bdde0ee9a1904f97de5f02d19016b8804f/tokenizers-0.8.1rc2-cp36-cp36m-manylinux1_x86_64.whl (3.0MB)\n",
      "\u001b[K     |████████████████████████████████| 3.0MB 19.1MB/s \n",
      "\u001b[?25hCollecting sacremoses\n",
      "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)\n",
      "\u001b[K     |████████████████████████████████| 890kB 29.7MB/s \n",
      "\u001b[?25hCollecting sentencepiece!=0.1.92\n",
      "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/d4/a4/d0a884c4300004a78cca907a6ff9a5e9fe4f090f5d95ab341c53d28cbc58/sentencepiece-0.1.91-cp36-cp36m-manylinux1_x86_64.whl (1.1MB)\n",
      "\u001b[K     |████████████████████████████████| 1.1MB 42.9MB/s \n",
      "\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.6/dist-packages (from transformers) (3.0.12)\n",
      "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.6/dist-packages (from transformers) (2019.12.20)\n",
      "Requirement already satisfied: dataclasses; python_version < \"3.7\" in /usr/local/lib/python3.6/dist-packages (from transformers) (0.7)\n",
      "Requirement already satisfied: packaging in /usr/local/lib/python3.6/dist-packages (from transformers) (20.4)\n",
      "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2.10)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2020.6.20)\n",
      "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (1.24.3)\n",
      "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (3.0.4)\n",
      "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (1.15.0)\n",
      "Requirement already satisfied: click in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (7.1.2)\n",
      "Requirement already satisfied: joblib in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (0.16.0)\n",
      "Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.6/dist-packages (from packaging->transformers) (2.4.7)\n",
      "Building wheels for collected packages: sacremoses\n",
      "  Building wheel for sacremoses (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
      "  Created wheel for sacremoses: filename=sacremoses-0.0.43-cp36-none-any.whl size=893257 sha256=434369e8d3251c915af7478b56822cc253e644d934513b9179062df8754c7c40\n",
      "  Stored in directory: /root/.cache/pip/wheels/29/3c/fd/7ce5c3f0666dab31a50123635e6fb5e19ceb42ce38d4e58f45\n",
      "Successfully built sacremoses\n",
      "Installing collected packages: tokenizers, sacremoses, sentencepiece, transformers\n",
      "Successfully installed sacremoses-0.0.43 sentencepiece-0.1.91 tokenizers-0.8.1rc2 transformers-3.3.1\n"
     ]
    }
   ],
   "source": [
    "# Pinned to the release this notebook was executed with (see the recorded pip output).\n",
    "%pip install transformers==3.3.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 374
    },
    "id": "1yNkBbk1nN4Y",
    "outputId": "41a70c84-5fc4-497c-bdeb-12b9d0db006d"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2020-10-15 12:49:35--  https://machinehack-be.s3.amazonaws.com/predict_github_issues_embold_sponsored_hackathon/Embold_Participant%27s_Dataset.zip\n",
      "Resolving machinehack-be.s3.amazonaws.com (machinehack-be.s3.amazonaws.com)... 52.219.62.32\n",
      "Connecting to machinehack-be.s3.amazonaws.com (machinehack-be.s3.amazonaws.com)|52.219.62.32|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 102320961 (98M) [application/octet-stream]\n",
      "Saving to: ‘Embold_Participant's_Dataset.zip’\n",
      "\n",
      "Embold_Participant' 100%[===================>]  97.58M  13.1MB/s    in 9.2s    \n",
      "\n",
      "2020-10-15 12:49:45 (10.7 MB/s) - ‘Embold_Participant's_Dataset.zip’ saved [102320961/102320961]\n",
      "\n",
      "Archive:  ./Embold_Participant's_Dataset.zip\n",
      "   creating: Dataset/Embold_Participant's_Dataset/\n",
      "  inflating: Dataset/Embold_Participant's_Dataset/sample submission.csv  \n",
      "  inflating: Dataset/__MACOSX/Embold_Participant's_Dataset/._sample submission.csv  \n",
      "  inflating: Dataset/Embold_Participant's_Dataset/embold_train_extra.json  \n",
      "  inflating: Dataset/__MACOSX/Embold_Participant's_Dataset/._embold_train_extra.json  \n",
      "  inflating: Dataset/Embold_Participant's_Dataset/embold_test.json  \n",
      "  inflating: Dataset/__MACOSX/Embold_Participant's_Dataset/._embold_test.json  \n",
      "  inflating: Dataset/Embold_Participant's_Dataset/embold_train.json  \n",
      "  inflating: Dataset/__MACOSX/Embold_Participant's_Dataset/._embold_train.json  \n"
     ]
    }
   ],
   "source": [
    "# -nc: skip the download when the archive already exists (plain wget would save a numbered duplicate).\n",
    "!wget -nc https://machinehack-be.s3.amazonaws.com/predict_github_issues_embold_sponsored_hackathon/Embold_Participant%27s_Dataset.zip\n",
    "# -n: never overwrite already-extracted files, so a re-run does not stop at unzip's replace prompt.\n",
    "!unzip -n ./Embold_Participant\\'s_Dataset.zip -d Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "id": "FsSUITDanP03",
    "outputId": "03386129-5e74-46fa-c900-9173a233df2f"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/content/Dataset/Embold_Participant's_Dataset\n"
     ]
    }
   ],
   "source": [
    "# Enter the extracted dataset directory; the JSON files are then read by relative path.\n",
    "%cd \"Dataset/Embold_Participant's_Dataset/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "warlGEqbnRq-"
   },
   "outputs": [],
   "source": [
    "#import pandas as pd\n",
    "import numpy as np\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "QRFbUvbrnTRp"
   },
   "outputs": [],
   "source": [
    "# Load the competition files (paths are relative to the current working directory).\n",
    "train_df = pd.read_json(\"embold_train.json\")\n",
    "train_df = train_df.reset_index(drop=True)\n",
    "test_df = pd.read_json(\"embold_test.json\")\n",
    "test_df = test_df.reset_index(drop=True)\n",
    "# The extra training file keeps the index produced by read_json.\n",
    "train_ex_df = pd.read_json(\"embold_train_extra.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "id": "j38nyQB6nUh1"
   },
   "outputs": [],
   "source": [
    "# Labels for the test rows come from a prediction CSV; no cell visible here creates this file.\n",
    "test_label_df = pd.read_csv('/content/ensemble_v8_1.csv')\n",
    "# The assignment below aligns on the index, so a row-count mismatch would silently leave NaN labels.\n",
    "assert len(test_label_df) == len(test_df), (\n",
    "    f\"label file has {len(test_label_df)} rows but test set has {len(test_df)}\"\n",
    ")\n",
    "test_df['label'] = test_label_df['label']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "id": "Zgh7-J09naRn"
   },
   "outputs": [],
   "source": [
    "# pd.concat replaces DataFrame.append, which is deprecated and removed in pandas 2.0.\n",
    "train_data = pd.concat([train_df, train_ex_df])\n",
    "# Model input text is the issue title and body joined by a single space.\n",
    "test_df['text'] = test_df['title']+' '+test_df['body']\n",
    "train_data['text'] = train_data['title']+' '+train_data['body']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "id": "XZ8EwQxhnb0q",
    "outputId": "b937fadf-023f-4d68-c95c-87155e7faa5f"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(450000, 4)"
      ]
     },
     "execution_count": 8,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Row and column count of the combined training frame.\n",
    "train_data.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "id": "Tj0t8X3OndOG"
   },
   "outputs": [],
   "source": [
    "from dataclasses import dataclass\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras.optimizers import Adam\n",
    "from transformers import *\n",
    "from transformers import pipeline\n",
    "from pprint import pprint\n",
    "from tensorflow import keras"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 765
    },
    "id": "qockFYJSne1L",
    "outputId": "a41c2f6c-a8be-4933-bad5-3194faf0f44f"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:absl:Entering into master device scope: /job:worker/replica:0/task:0/device:CPU:0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running on TPU  grpc://10.75.187.130:8470\n",
      "INFO:tensorflow:Initializing the TPU system: grpc://10.75.187.130:8470\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Initializing the TPU system: grpc://10.75.187.130:8470\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Clearing out eager caches\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Clearing out eager caches\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Finished initializing TPU system.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Finished initializing TPU system.\n",
      "WARNING:absl:`tf.distribute.experimental.TPUStrategy` is deprecated, please use  the non experimental symbol `tf.distribute.TPUStrategy` instead.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Found TPU system:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Found TPU system:\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Num TPU Cores: 8\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Num TPU Cores: 8\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Num TPU Workers: 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Num TPU Workers: 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Num TPU Cores Per Worker: 8\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Num TPU Cores Per Worker: 8\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "REPLICAS:  8\n"
     ]
    }
   ],
   "source": [
    "@dataclass\n",
    "class Config:\n",
    "    \"\"\"Hyper-parameters shared by the tokenisation, model and TPU set-up cells.\n",
    "\n",
    "    The attributes are annotated so that ``@dataclass`` registers them as real\n",
    "    fields: they show up in ``repr()`` and can be overridden per instance, e.g.\n",
    "    ``Config(MAX_LEN=128)``. ``Config()`` keeps the defaults below.\n",
    "    \"\"\"\n",
    "\n",
    "    MAX_LEN: int = 320  # sequence length used for tokenisation and the model input\n",
    "    BATCH_SIZE: int = 16  # per TPU core; scaled by the replica count in connect_to_TPU\n",
    "    # NOTE(review): originally annotated \"approx 4 epochs\" - that depends on the\n",
    "    # dataset size and global batch size; confirm before relying on it.\n",
    "    TOTAL_STEPS: int = 2000\n",
    "    EVALUATE_EVERY: int = 200\n",
    "    LR: float = 1e-5  # learning rate handed to the Adam optimizer\n",
    "    PRETRAINED_MODEL: str = \"xlnet-base-cased\"  # Hugging Face checkpoint name (XLNet)\n",
    "\n",
    "\n",
    "transformer_flags = Config()\n",
    "\n",
    "\n",
    "def connect_to_TPU():\n",
    "    \"\"\"Detect the hardware and return a matching distribution strategy.\n",
    "\n",
    "    Returns:\n",
    "        tuple: ``(tpu, strategy, global_batch_size)`` where ``tpu`` is the\n",
    "        TPUClusterResolver (``None`` when no TPU was detected), ``strategy`` is\n",
    "        the tf.distribute strategy to build the model under, and\n",
    "        ``global_batch_size`` is ``transformer_flags.BATCH_SIZE`` multiplied by\n",
    "        ``strategy.num_replicas_in_sync``.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        # TPU detection: called without arguments, the resolver looks the TPU up\n",
    "        # from the runtime environment. The ValueError handled below is treated\n",
    "        # as the no-TPU case.\n",
    "        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()\n",
    "        print(\"Running on TPU \", tpu.master())\n",
    "    except ValueError:\n",
    "        tpu = None\n",
    "\n",
    "    if tpu:\n",
    "        tf.config.experimental_connect_to_cluster(tpu)\n",
    "        tf.tpu.experimental.initialize_tpu_system(tpu)\n",
    "        # Prefer the stable symbol; the experimental alias is deprecated and is\n",
    "        # only used as a fallback on TensorFlow versions that lack the stable one.\n",
    "        strategy_cls = getattr(tf.distribute, \"TPUStrategy\", None)\n",
    "        if strategy_cls is None:\n",
    "            strategy_cls = tf.distribute.experimental.TPUStrategy\n",
    "        strategy = strategy_cls(tpu)\n",
    "    else:\n",
    "        # Default distribution strategy in Tensorflow. Works on CPU and single GPU.\n",
    "        strategy = tf.distribute.get_strategy()\n",
    "\n",
    "    global_batch_size = transformer_flags.BATCH_SIZE * strategy.num_replicas_in_sync\n",
    "\n",
    "    return tpu, strategy, global_batch_size\n",
    "\n",
    "\n",
    "tpu, strategy, global_batch_size = connect_to_TPU()\n",
    "print(\"REPLICAS: \", strategy.num_replicas_in_sync)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 66,
     "referenced_widgets": [
      "ca4cfce9073f4f44a4ad5bcd8bf35229",
      "2f239b4e91a44f5a9ee053be723be12a",
      "d8a6ca89e23b4a1f9e30a59ab8bdeb56",
      "6a42844a65a741578cfc707a44ba311a",
      "b609eeb7a5f445a08d98b25317a95b4d",
      "bbd6a8baec2443a3aff56a82afc60118",
      "f6205306fd6a42f6b35aa8876eb87b10",
      "94e1339d390344c09e9ac8312c0bb40c"
     ]
    },
    "id": "trvy8JEyngZO",
    "outputId": "5e534a57-59eb-4149-f504-88b7c6de3c76"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ca4cfce9073f4f44a4ad5bcd8bf35229",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=798011.0, style=ProgressStyle(descripti…"
      ]
     },
     "metadata": {
      "tags": []
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "def regular_encode(texts, tokenizer, maxlen=512):\n",
    "    \"\"\"Tokenize ``texts`` into a fixed-width array of input ids.\n",
    "\n",
    "    Args:\n",
    "        texts: batch of texts to encode (the callers pass a list of strings).\n",
    "        tokenizer: tokenizer object exposing ``batch_encode_plus``.\n",
    "        maxlen: length every sequence is padded or truncated to.\n",
    "\n",
    "    Returns:\n",
    "        np.ndarray built from the tokenizer's ``input_ids``. Attention masks and\n",
    "        token type ids are not requested.\n",
    "    \"\"\"\n",
    "    enc_di = tokenizer.batch_encode_plus(\n",
    "        texts,\n",
    "        return_attention_mask=False,\n",
    "        return_token_type_ids=False,\n",
    "        # `padding=\"max_length\"` replaces the deprecated `pad_to_max_length=True`,\n",
    "        # which raises a FutureWarning on every call.\n",
    "        padding=\"max_length\",\n",
    "        max_length=maxlen,\n",
    "        truncation=True,\n",
    "    )\n",
    "\n",
    "    return np.array(enc_di[\"input_ids\"])\n",
    "\n",
    "from transformers import XLNetTokenizer\n",
    "\n",
    "# Tokenizer for the checkpoint named in the config.\n",
    "# NOTE(review): `use_fast` is documented as an AutoTokenizer option; confirm it has any effect here.\n",
    "tokenizer = XLNetTokenizer.from_pretrained(transformer_flags.PRETRAINED_MODEL, use_fast=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "id": "XjEYPJFTnz6p"
   },
   "outputs": [],
   "source": [
    "def build_model():\n",
    "    \"\"\"Create and compile the 3-class classifier on top of the pretrained encoder.\n",
    "\n",
    "    The encoder named by ``transformer_flags.PRETRAINED_MODEL`` receives only the\n",
    "    token ids; its first output is average-pooled and fed to a softmax layer.\n",
    "    The model summary is printed before returning.\n",
    "\n",
    "    Returns:\n",
    "        The compiled ``tf.keras.Model``.\n",
    "    \"\"\"\n",
    "    token_ids = tf.keras.layers.Input((transformer_flags.MAX_LEN,), dtype=tf.int32)\n",
    "    encoder = TFAutoModel.from_pretrained(transformer_flags.PRETRAINED_MODEL)\n",
    "    sequence_output = encoder(token_ids)[0]\n",
    "    # NOTE(review): no attention mask is passed to the encoder or the pooling layer,\n",
    "    # so padded positions presumably contribute to the average - confirm.\n",
    "    pooled = tf.keras.layers.GlobalAveragePooling1D()(sequence_output)\n",
    "    probabilities = tf.keras.layers.Dense(3, activation='softmax')(pooled)\n",
    "    classifier_model = tf.keras.Model(token_ids, probabilities)\n",
    "\n",
    "    classifier_model.compile(\n",
    "        optimizer=tf.keras.optimizers.Adam(learning_rate=transformer_flags.LR),\n",
    "        loss='sparse_categorical_crossentropy',\n",
    "        metrics=['accuracy'],\n",
    "    )\n",
    "    classifier_model.summary()\n",
    "\n",
    "    return classifier_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "id": "iXK-Ecevn1-K"
   },
   "outputs": [],
   "source": [
    "# Training pool: every row except positions 200000-229999, which another cell uses for validation.\n",
    "# pd.concat replaces DataFrame.append, which is deprecated and removed in pandas 2.0;\n",
    "# .iloc makes the positional slicing explicit.\n",
    "all_train = pd.concat([train_data.iloc[:200000], train_data.iloc[230000:]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "id": "-uVdJSfSn3lw"
   },
   "outputs": [],
   "source": [
    "# Fixed seed so the down-sampling and the shuffle give the same rows on every run.\n",
    "SAMPLE_SEED = 42\n",
    "\n",
    "# Keep every label-2 row and a random 80000-row subset of the other labels.\n",
    "question_df = all_train[all_train.label==2]\n",
    "bug_issue = all_train[all_train.label!=2]\n",
    "bug_issue = bug_issue.sample(80000, random_state=SAMPLE_SEED)\n",
    "# The test rows, with the labels attached earlier, are added to the training pool.\n",
    "# pd.concat replaces DataFrame.append, which is deprecated and removed in pandas 2.0.\n",
    "train = pd.concat([question_df, bug_issue, test_df])\n",
    "# Shuffle all rows.\n",
    "train = train.sample(frac=1, random_state=SAMPLE_SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 85
    },
    "id": "TD4VC2PFn7T3",
    "outputId": "3393c06c-db38-4176-ee8b-402c32672d8f"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1    54691\n",
       "0    52972\n",
       "2    41662\n",
       "Name: label, dtype: int64"
      ]
     },
     "execution_count": 27,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of rows per label in the assembled training set.\n",
    "train.label.value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 85
    },
    "id": "SJeGIbRkn9Ac",
    "outputId": "ea1423a5-2f12-49e0-8ad4-c2f536bbb640"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1    13846\n",
       "0    13278\n",
       "2     2876\n",
       "Name: label, dtype: int64"
      ]
     },
     "execution_count": 28,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Validation split: the rows at positions 200000-229999 that all_train leaves out.\n",
    "validate = train_data.iloc[200000:230000]\n",
    "validate[\"label\"].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 526,
     "referenced_widgets": [
      "e0e065231d6046e7ba2daad21523138e",
      "12541e3548f24981bf294aa12af21d20",
      "2a696b731e55489b92d12ceb3ddd4fbf",
      "b05b0b68d5fa4f0eb444288617cd7295",
      "af0b1a72960142c78a43da2ea457bae5",
      "712e9f0879f74f6a98f2eb89a666a17a",
      "5ec59df6b5f14583bca47c7a425eb216",
      "08145e505d5a404cad1ccdc2c7247da3",
      "29b6dd9a527d4c66ac5b1af4a089b940",
      "f0fecf29594e412b85673cccf45d57f8",
      "1323808b10aa4543994a1d87502ec0a8",
      "1b2a683a3acf47ea81b42b3347993391",
      "fd722e80c6024986bce1f218ad9ffa72",
      "c70cac1ba59a4ae680c595a9ac9ef5e5",
      "b174e06a02af44be82faea18b15dbd78",
      "4e9b031c936a49a8bd2ba2f68513f5db"
     ]
    },
    "id": "AVnTTo3Sn-WL",
    "outputId": "a3457bc5-d436-4208-91f9-2c872c03d69f"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e0e065231d6046e7ba2daad21523138e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=760.0, style=ProgressStyle(description_…"
      ]
     },
     "metadata": {
      "tags": []
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/dist-packages/transformers/configuration_xlnet.py:212: FutureWarning: This config doesn't use attention memories, a core feature of XLNet. Consider setting `men_len` to a non-zero value, for example `xlnet = XLNetLMHeadModel.from_pretrained('xlnet-base-cased'', mem_len=1024)`, for accurate training performance as well as an order of magnitude faster inference. Starting from version 3.5.0, the default parameter will be 1024, following the implementation in https://arxiv.org/abs/1906.08237\n",
      "  FutureWarning,\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "29b6dd9a527d4c66ac5b1af4a089b940",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=565485600.0, style=ProgressStyle(descri…"
      ]
     },
     "metadata": {
      "tags": []
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at xlnet-base-cased were not used when initializing TFXLNetModel: ['lm_loss']\n",
      "- This IS expected if you are initializing TFXLNetModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n",
      "- This IS NOT expected if you are initializing TFXLNetModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "All the weights of TFXLNetModel were initialized from the model checkpoint at xlnet-base-cased.\n",
      "If your task is similar to the task the model of the checkpoint was trained on, you can already use TFXLNetModel for predictions without further training.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"functional_1\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "input_1 (InputLayer)         [(None, 320)]             0         \n",
      "_________________________________________________________________\n",
      "tfxl_net_model (TFXLNetModel ((None, 320, 768),)       116718336 \n",
      "_________________________________________________________________\n",
      "global_average_pooling1d (Gl (None, 768)               0         \n",
      "_________________________________________________________________\n",
      "dense (Dense)                (None, 3)                 2307      \n",
      "=================================================================\n",
      "Total params: 116,720,643\n",
      "Trainable params: 116,720,643\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Build and compile the classifier inside the distribution strategy's scope.\n",
    "with strategy.scope():\n",
    "    classifier_model = build_model()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 88
    },
    "id": "Kww5N1DToAV-",
    "outputId": "2f351f55-2ed3-4a95-d652-33d80529deb4"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/dist-packages/transformers/tokenization_utils_base.py:1773: FutureWarning: The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).\n",
      "  FutureWarning,\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(149325, 320) (149325,)\n"
     ]
    }
   ],
   "source": [
    "# Encode the training texts to fixed-length id arrays and collect the matching labels.\n",
    "X_data = regular_encode(\n",
    "    train[\"text\"].tolist(),\n",
    "    tokenizer,\n",
    "    maxlen=transformer_flags.MAX_LEN,\n",
    ")\n",
    "y_train = train[\"label\"].values\n",
    "print(X_data.shape, y_train.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 88
    },
    "id": "wz7visPqoBsu",
    "outputId": "501cebb1-c985-4bb1-a6c3-b7bcd057a85e"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/dist-packages/transformers/tokenization_utils_base.py:1773: FutureWarning: The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).\n",
      "  FutureWarning,\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(30000, 320) (30000,)\n"
     ]
    }
   ],
   "source": [
    "# Encode the validation texts the same way as the training texts.\n",
    "X_val = regular_encode(\n",
    "    validate[\"text\"].tolist(),\n",
    "    tokenizer,\n",
    "    maxlen=transformer_flags.MAX_LEN,\n",
    ")\n",
    "y_val = validate[\"label\"].values\n",
    "print(X_val.shape, y_val.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 71
    },
    "id": "3eFtbU5zoC_e",
    "outputId": "c9cbffeb-2ca0-4f05-edd6-1954c22ef1ea"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/dist-packages/transformers/tokenization_utils_base.py:1773: FutureWarning: The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).\n",
      "  FutureWarning,\n"
     ]
    }
   ],
   "source": [
    "# Test split: encode the texts only (no label array is built in this cell)\n",
    "X_test = regular_encode(\n",
    "    test_df.text.values.tolist(),\n",
    "    tokenizer,\n",
    "    maxlen=transformer_flags.MAX_LEN,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 377
    },
    "id": "_HuEu4LroEbT",
    "outputId": "91975df7-7481-406c-9c10-6a11cc05adde"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.data.Iterator.get_next_as_optional()` instead.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.data.Iterator.get_next_as_optional()` instead.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Gradients do not exist for variables ['tfxl_net_model/transformer/mask_emb:0', 'tfxl_net_model/transformer/layer_._0/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._0/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._1/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._1/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._2/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._2/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._3/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._3/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._4/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._4/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._5/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._5/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._6/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._6/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._7/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._7/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._8/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._8/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._9/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._9/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._10/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._10/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._11/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._11/rel_attn/seg_embed:0'] when minimizing the loss.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Gradients do not exist for variables ['tfxl_net_model/transformer/mask_emb:0', 'tfxl_net_model/transformer/layer_._0/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._0/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._1/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._1/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._2/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._2/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._3/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._3/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._4/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._4/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._5/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._5/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._6/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._6/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._7/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._7/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._8/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._8/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._9/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._9/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._10/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._10/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._11/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._11/rel_attn/seg_embed:0'] when minimizing the loss.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Gradients do not exist for variables ['tfxl_net_model/transformer/mask_emb:0', 'tfxl_net_model/transformer/layer_._0/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._0/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._1/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._1/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._2/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._2/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._3/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._3/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._4/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._4/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._5/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._5/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._6/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._6/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._7/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._7/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._8/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._8/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._9/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._9/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._10/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._10/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._11/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._11/rel_attn/seg_embed:0'] when minimizing the loss.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Gradients do not exist for variables ['tfxl_net_model/transformer/mask_emb:0', 'tfxl_net_model/transformer/layer_._0/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._0/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._1/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._1/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._2/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._2/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._3/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._3/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._4/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._4/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._5/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._5/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._6/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._6/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._7/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._7/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._8/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._8/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._9/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._9/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._10/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._10/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._11/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._11/rel_attn/seg_embed:0'] when minimizing the loss.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Gradients do not exist for variables ['tfxl_net_model/transformer/mask_emb:0', 'tfxl_net_model/transformer/layer_._0/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._0/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._1/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._1/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._2/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._2/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._3/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._3/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._4/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._4/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._5/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._5/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._6/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._6/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._7/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._7/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._8/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._8/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._9/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._9/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._10/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._10/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._11/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._11/rel_attn/seg_embed:0'] when minimizing the loss.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Gradients do not exist for variables ['tfxl_net_model/transformer/mask_emb:0', 'tfxl_net_model/transformer/layer_._0/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._0/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._1/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._1/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._2/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._2/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._3/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._3/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._4/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._4/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._5/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._5/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._6/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._6/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._7/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._7/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._8/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._8/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._9/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._9/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._10/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._10/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._11/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._11/rel_attn/seg_embed:0'] when minimizing the loss.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Gradients do not exist for variables ['tfxl_net_model/transformer/mask_emb:0', 'tfxl_net_model/transformer/layer_._0/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._0/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._1/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._1/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._2/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._2/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._3/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._3/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._4/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._4/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._5/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._5/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._6/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._6/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._7/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._7/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._8/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._8/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._9/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._9/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._10/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._10/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._11/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._11/rel_attn/seg_embed:0'] when minimizing the loss.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Gradients do not exist for variables ['tfxl_net_model/transformer/mask_emb:0', 'tfxl_net_model/transformer/layer_._0/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._0/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._1/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._1/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._2/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._2/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._3/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._3/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._4/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._4/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._5/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._5/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._6/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._6/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._7/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._7/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._8/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._8/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._9/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._9/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._10/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._10/rel_attn/seg_embed:0', 'tfxl_net_model/transformer/layer_._11/rel_attn/r_s_bias:0', 'tfxl_net_model/transformer/layer_._11/rel_attn/seg_embed:0'] when minimizing the loss.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   2/2646 [..............................] - ETA: 13:13:45 - loss: 3.0848 - accuracy: 0.3125WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0094s vs `on_train_batch_end` time: 0.3037s). Check your callbacks.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0094s vs `on_train_batch_end` time: 0.3037s). Check your callbacks.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2646/2646 [==============================] - ETA: 0s - loss: 0.5483 - accuracy: 0.7859WARNING:tensorflow:Callbacks method `on_test_batch_end` is slow compared to the batch time (batch time: 0.0034s vs `on_test_batch_end` time: 0.1211s). Check your callbacks.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Callbacks method `on_test_batch_end` is slow compared to the batch time (batch time: 0.0034s vs `on_test_batch_end` time: 0.1211s). Check your callbacks.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
      "2646/2646 [==============================] - 976s 369ms/step - loss: 0.5483 - accuracy: 0.7859 - val_loss: 0.4949 - val_accuracy: 0.8151\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f76b282ff60>"
      ]
     },
     "execution_count": 21,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Run tagged '100K datapoint' by the author: one epoch, batch size 64, validated on (X_val, y_val)\n",
    "classifier_model.fit(\n",
    "    X_data,\n",
    "    y_train,\n",
    "    epochs=1,\n",
    "    batch_size=64,\n",
    "    validation_data=(X_val, y_val),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 139
    },
    "id": "ei0a5qr0zuXM",
    "outputId": "9db68449-f3e7-405d-cb6e-b5e9a9bf5c83"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   2/2334 [..............................] - ETA: 12:05 - loss: 0.4916 - accuracy: 0.8203WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0121s vs `on_train_batch_end` time: 0.3048s). Check your callbacks.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0121s vs `on_train_batch_end` time: 0.3048s). Check your callbacks.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2334/2334 [==============================] - ETA: 0s - loss: 0.4819 - accuracy: 0.8161WARNING:tensorflow:Callbacks method `on_test_batch_end` is slow compared to the batch time (batch time: 0.0043s vs `on_test_batch_end` time: 0.1179s). Check your callbacks.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Callbacks method `on_test_batch_end` is slow compared to the batch time (batch time: 0.0043s vs `on_test_batch_end` time: 0.1179s). Check your callbacks.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
      "2334/2334 [==============================] - 834s 357ms/step - loss: 0.4819 - accuracy: 0.8161 - val_loss: 0.4744 - val_accuracy: 0.8240\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f76b282f748>"
      ]
     },
     "execution_count": 30,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Run tagged '80K datapoint' by the author: one epoch, batch size 64, validated on (X_val, y_val)\n",
    "# NOTE(review): same call as the '100K datapoint' cell; presumably X_data / y_train were rebuilt in between - confirm\n",
    "classifier_model.fit(\n",
    "    X_data,\n",
    "    y_train,\n",
    "    epochs=1,\n",
    "    batch_size=64,\n",
    "    validation_data=(X_val, y_val),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 139
    },
    "id": "r3vPF-25oHL6",
    "outputId": "03e8c367-a36b-4b21-bee9-191d2a40f91e"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  2/235 [..............................] - ETA: 29sWARNING:tensorflow:Callbacks method `on_predict_batch_end` is slow compared to the batch time (batch time: 0.0039s vs `on_predict_batch_end` time: 0.2209s). Check your callbacks.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Callbacks method `on_predict_batch_end` is slow compared to the batch time (batch time: 0.0039s vs `on_predict_batch_end` time: 0.2209s). Check your callbacks.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "235/235 [==============================] - 53s 226ms/step\n",
      "  2/235 [..............................] - ETA: 26sWARNING:tensorflow:Callbacks method `on_predict_batch_end` is slow compared to the batch time (batch time: 0.0039s vs `on_predict_batch_end` time: 0.2216s). Check your callbacks.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Callbacks method `on_predict_batch_end` is slow compared to the batch time (batch time: 0.0039s vs `on_predict_batch_end` time: 0.2216s). Check your callbacks.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "235/235 [==============================] - 53s 226ms/step\n"
     ]
    }
   ],
   "source": [
    "# Run inference on the validation and test encodings with the shared batch size\n",
    "y_val_pred = classifier_model.predict(\n",
    "    X_val,\n",
    "    batch_size=global_batch_size,\n",
    "    verbose=1,\n",
    ")\n",
    "y_test_pred = classifier_model.predict(\n",
    "    X_test,\n",
    "    batch_size=global_batch_size,\n",
    "    verbose=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "id": "xKF-JI47pog9"
   },
   "outputs": [],
   "source": [
    "# Pair each validation prediction ('X') with the matching validation label ('y')\n",
    "eval_df = pd.DataFrame(\n",
    "    {\n",
    "        'X': y_val_pred.tolist(),\n",
    "        'y': y_val.tolist(),\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "id": "_1qStVjCpofV"
   },
   "outputs": [],
   "source": [
    "# Save the validation predictions and labels to CSV without the index column\n",
    "eval_df.to_csv(\n",
    "    'xlnet-base-maxlen320-PL-2epochs-80K-eval-prob.csv',\n",
    "    index=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "id": "VLwVLZrNpoeE"
   },
   "outputs": [],
   "source": [
    "# Fill the sample-submission frame's 'label' column with the test predictions and save it\n",
    "file_name = 'xlnet-base-maxlen320-PL-2epochs-80K-test-prob.csv'\n",
    "submission = pd.read_csv('sample submission.csv')\n",
    "submission['label'] = y_test_pred.tolist()\n",
    "submission.to_csv(file_name, index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "GbVAycV8p4dJ"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "TPU",
  "colab": {
   "collapsed_sections": [],
   "include_colab_link": true,
   "name": "MachineHack_xlnet_base.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "08145e505d5a404cad1ccdc2c7247da3": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "12541e3548f24981bf294aa12af21d20": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "1323808b10aa4543994a1d87502ec0a8": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "Downloading: 100%",
      "description_tooltip": null,
      "layout": "IPY_MODEL_c70cac1ba59a4ae680c595a9ac9ef5e5",
      "max": 565485600,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_fd722e80c6024986bce1f218ad9ffa72",
      "value": 565485600
     }
    },
    "1b2a683a3acf47ea81b42b3347993391": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_4e9b031c936a49a8bd2ba2f68513f5db",
      "placeholder": "​",
      "style": "IPY_MODEL_b174e06a02af44be82faea18b15dbd78",
      "value": " 565M/565M [00:13&lt;00:00, 41.9MB/s]"
     }
    },
    "29b6dd9a527d4c66ac5b1af4a089b940": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_1323808b10aa4543994a1d87502ec0a8",
       "IPY_MODEL_1b2a683a3acf47ea81b42b3347993391"
      ],
      "layout": "IPY_MODEL_f0fecf29594e412b85673cccf45d57f8"
     }
    },
    "2a696b731e55489b92d12ceb3ddd4fbf": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "Downloading: 100%",
      "description_tooltip": null,
      "layout": "IPY_MODEL_712e9f0879f74f6a98f2eb89a666a17a",
      "max": 760,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_af0b1a72960142c78a43da2ea457bae5",
      "value": 760
     }
    },
    "2f239b4e91a44f5a9ee053be723be12a": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "4e9b031c936a49a8bd2ba2f68513f5db": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "5ec59df6b5f14583bca47c7a425eb216": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "6a42844a65a741578cfc707a44ba311a": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_94e1339d390344c09e9ac8312c0bb40c",
      "placeholder": "​",
      "style": "IPY_MODEL_f6205306fd6a42f6b35aa8876eb87b10",
      "value": " 798k/798k [00:09&lt;00:00, 82.2kB/s]"
     }
    },
    "712e9f0879f74f6a98f2eb89a666a17a": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "94e1339d390344c09e9ac8312c0bb40c": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "af0b1a72960142c78a43da2ea457bae5": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "b05b0b68d5fa4f0eb444288617cd7295": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_08145e505d5a404cad1ccdc2c7247da3",
      "placeholder": "​",
      "style": "IPY_MODEL_5ec59df6b5f14583bca47c7a425eb216",
      "value": " 760/760 [00:13&lt;00:00, 55.2B/s]"
     }
    },
    "b174e06a02af44be82faea18b15dbd78": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "b609eeb7a5f445a08d98b25317a95b4d": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "bbd6a8baec2443a3aff56a82afc60118": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "c70cac1ba59a4ae680c595a9ac9ef5e5": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "ca4cfce9073f4f44a4ad5bcd8bf35229": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_d8a6ca89e23b4a1f9e30a59ab8bdeb56",
       "IPY_MODEL_6a42844a65a741578cfc707a44ba311a"
      ],
      "layout": "IPY_MODEL_2f239b4e91a44f5a9ee053be723be12a"
     }
    },
    "d8a6ca89e23b4a1f9e30a59ab8bdeb56": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "Downloading: 100%",
      "description_tooltip": null,
      "layout": "IPY_MODEL_bbd6a8baec2443a3aff56a82afc60118",
      "max": 798011,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_b609eeb7a5f445a08d98b25317a95b4d",
      "value": 798011
     }
    },
    "e0e065231d6046e7ba2daad21523138e": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_2a696b731e55489b92d12ceb3ddd4fbf",
       "IPY_MODEL_b05b0b68d5fa4f0eb444288617cd7295"
      ],
      "layout": "IPY_MODEL_12541e3548f24981bf294aa12af21d20"
     }
    },
    "f0fecf29594e412b85673cccf45d57f8": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "f6205306fd6a42f6b35aa8876eb87b10": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "fd722e80c6024986bce1f218ad9ffa72": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
